{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Classification on Iris dataset with sklearn and DJL\n",
    "\n",
    "In this notebook, you will run a pre-trained sklearn model on DJL for a general classification task. The model was trained on the [Iris flower dataset](https://en.wikipedia.org/wiki/Iris_flower_data_set).\n",
    "\n",
    "## Background \n",
    "\n",
    "### Iris Dataset\n",
    "\n",
    "The dataset contains 150 records with five attributes: sepal length, sepal width, petal length, petal width, and species.\n",
    "\n",
    "Iris setosa                |  Iris versicolor          | Iris virginica\n",
    ":-------------------------:|:-------------------------:|:-------------------------:\n",
    "![](https://upload.wikimedia.org/wikipedia/commons/5/56/Kosaciec_szczecinkowaty_Iris_setosa.jpg)  |  ![](https://upload.wikimedia.org/wikipedia/commons/4/41/Iris_versicolor_3.jpg) | ![](https://upload.wikimedia.org/wikipedia/commons/9/9f/Iris_virginica.jpg) \n",
    "\n",
    "The table above shows the three Iris species in the dataset.\n",
    "\n",
    "We will use sepal length, sepal width, petal length, and petal width as the features and species as the label to train the model.\n",
    "\n",
    "### Sklearn Model\n",
    "\n",
    "The model was trained on sklearn's built-in iris dataset. We defined a [RandomForestClassifier](https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html) and trained it on that data, then converted the trained model to ONNX format so that DJL can run inference on it. You can find more information about converting sklearn models to ONNX [here](http://onnx.ai/sklearn-onnx/). The following code is a sample classification setup using sklearn:\n",
    "\n",
    "```python\n",
    "# Train a model.\n",
    "from sklearn.datasets import load_iris\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "iris = load_iris()\n",
    "X, y = iris.data, iris.target\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y)\n",
    "clr = RandomForestClassifier()\n",
    "clr.fit(X_train, y_train)\n",
    "```\n",
    "\n",
    "\n",
    "## Preparation\n",
    "\n",
    "This tutorial requires a Java kernel for Jupyter. To install it, see the [README](https://github.com/awslabs/djl/blob/master/jupyter/README.md).\n",
    "\n",
    "The next cell declares the dependencies we will use. To extend the NDArray operations available, we import the ONNX Runtime engine and the PyTorch engine at the same time. You can find more information [here](https://github.com/awslabs/djl/blob/master/docs/onnxruntime/hybrid_engine.md#hybrid-engine-for-onnx-runtime)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n",
    "\n",
    "%maven ai.djl:api:0.8.0\n",
    "%maven ai.djl.onnxruntime:onnxruntime-engine:0.8.0\n",
    "%maven ai.djl.pytorch:pytorch-engine:0.8.0\n",
    "%maven org.slf4j:slf4j-api:1.7.26\n",
    "%maven org.slf4j:slf4j-simple:1.7.26\n",
    "\n",
    "%maven com.microsoft.onnxruntime:onnxruntime:1.4.0\n",
    "%maven ai.djl.pytorch:pytorch-native-auto:1.6.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ai.djl.inference.*;\n",
    "import ai.djl.modality.*;\n",
    "import ai.djl.ndarray.*;\n",
    "import ai.djl.ndarray.types.*;\n",
    "import ai.djl.repository.zoo.*;\n",
    "import ai.djl.translate.*;\n",
    "import java.util.*;"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1 Create a Translator\n",
    "\n",
    "Inference in machine learning is the process of predicting the output for a given input based on a pre-defined model.\n",
    "DJL abstracts away the whole process for ease of use. It can load the model, perform inference on the input, and provide\n",
    "output. DJL also allows you to provide user-defined inputs. The workflow looks like the following:\n",
    "\n",
    "![DJL inference workflow](https://github.com/awslabs/djl/blob/master/examples/docs/img/workFlow.png?raw=true)\n",
    "\n",
    "The `Translator` interface encompasses the two white blocks: Pre-processing and Post-processing. The pre-processing\n",
    "component converts the user-defined input objects into an NDList, so that the `Predictor` in DJL can understand the\n",
    "input and make its prediction. Similarly, the post-processing block receives an NDList as the output from the\n",
    "`Predictor`. The post-processing block allows you to convert the output from the `Predictor` to the desired output\n",
    "format.\n",
    "\n",
    "In our use case, we use a class named `FlowerInfo` as the input type and [`Classifications`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/modality/Classifications.html) as the output type."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public static class FlowerInfo {\n",
    "\n",
    "    public float sepalLength;\n",
    "    public float sepalWidth;\n",
    "    public float petalLength;\n",
    "    public float petalWidth;\n",
    "\n",
    "    public FlowerInfo(float sepalLength, float sepalWidth, float petalLength, float petalWidth) {\n",
    "        this.sepalLength = sepalLength;\n",
    "        this.sepalWidth = sepalWidth;\n",
    "        this.petalLength = petalLength;\n",
    "        this.petalWidth = petalWidth;\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's create a translator that converts a `FlowerInfo` into an `NDList` and the model output into `Classifications`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public static class MyTranslator implements Translator<FlowerInfo, Classifications> {\n",
    "\n",
    "    private final List<String> synset;\n",
    "\n",
    "    public MyTranslator() {\n",
    "        // species name\n",
    "        synset = Arrays.asList(\"setosa\", \"versicolor\", \"virginica\");\n",
    "    }\n",
    "\n",
    "    @Override\n",
    "    public NDList processInput(TranslatorContext ctx, FlowerInfo input) {\n",
    "        // Pack the four features into a 1x4 float array (one sample, four features)\n",
    "        float[] data = {input.sepalLength, input.sepalWidth, input.petalLength, input.petalWidth};\n",
    "        NDArray array = ctx.getNDManager().create(data, new Shape(1, 4));\n",
    "        return new NDList(array);\n",
    "    }\n",
    "\n",
    "    @Override\n",
    "    public Classifications processOutput(TranslatorContext ctx, NDList list) {\n",
    "        // The second output (index 1) is expected to hold the per-class probabilities\n",
    "        return new Classifications(synset, list.get(1));\n",
    "    }\n",
    "\n",
    "    @Override\n",
    "    public Batchifier getBatchifier() {\n",
    "        // No batchifier: processInput already adds the batch dimension\n",
    "        return null;\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2 Prepare your model\n",
    "\n",
    "We will load a pretrained sklearn model into DJL. DJL provides a [`ModelZoo`](https://javadoc.io/doc/ai.djl/api/latest/ai/djl/repository/zoo/ModelZoo.html) abstraction that lets you load models from a variety of locations, such as a remote URL, local files, or the DJL pretrained model zoo. We define a `Criteria` object to help the model zoo locate the model and attach the translator. In this example, we download a compressed ONNX model from S3."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "String modelUrl = \"https://alpha-djl-demos.s3.amazonaws.com/model/sklearn/rf_iris.zip\";\n",
    "Criteria<FlowerInfo, Classifications> criteria = Criteria.builder()\n",
    "        .setTypes(FlowerInfo.class, Classifications.class)\n",
    "        .optModelUrls(modelUrl)\n",
    "        .optTranslator(new MyTranslator())\n",
    "        .optEngine(\"OnnxRuntime\") // explicitly select the OnnxRuntime engine\n",
    "        .build();\n",
    "ZooModel<FlowerInfo, Classifications> model = ModelZoo.loadModel(criteria);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3 Run inference\n",
    "\n",
    "You only need to create a `Predictor` from the model to run inference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Predictor<FlowerInfo, Classifications> predictor = model.newPredictor();\n",
    "FlowerInfo info = new FlowerInfo(1.0f, 2.0f, 3.0f, 4.0f);\n",
    "predictor.predict(info);"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "12.0.2+10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
